Sobreajuste y entrenamiento de modelos

Fecha de publicación

4 de noviembre de 2024

Objetivo del manual

Paquetes a utilizar en este manual:

Código
# instalar/cargar paquetes

sketchy::load_packages(
  c("ggplot2", 
    "viridis", 
    "caret"
   )
  )
Loading required package: caret
Loading required package: lattice

Funciones personalizadas a utilizar en este manual:

Código
rmse_lm <- function(model, data, response = "y"){
  pred <- predict(model, newdata = data)
  sqrt(mean((data[, response] - pred)^2))
}

r2_lm <- function(model, data, response = "y"){
  summary(model)$r.squared
  }

1 Entrenamiento de modelos

El objetivo de los modelos de aprendizaje estadístico es el de obtener patrones de los datos de entrenamiento para predecir o inferir correctamente los patrones en la población original de donde provienen esos datos de entrenamiento. Es decir, la clave esta en obtener patrones generales que sean extrapolables a nuevos datos. La idea principal del entrenamiento es ajustar el modelo a los datos de entrenamiento para aprender patrones que se puedan generalizar a datos nuevos. Sin embargo, parte de este proceso implica estrategias para evitar tanto el sobreajuste como el subajuste.

1.1 Sobreajuste

El sobreajuste ocurre cuando un modelo se ajusta demasiado bien a los datos de entrenamiento, capturando tanto los patrones verdaderos como el ruido o variaciones aleatorias de los datos. Como resultado, el modelo funciona bien en el conjunto de entrenamiento, pero tiene un rendimiento deficiente en nuevos datos (pobre capacidad de generalización). El sobreajuste se refiere a cuando modelo está tan ajustado a los datos de entrenamiento que afecta su capacidad de generalización. El sobreajuste se produce cuando un sistema de aprendizaje automático se entrena demasiado o con datos (levemente) sesgados, que hace que el algoritmo aprenda patrones que no son generales. Aprende características especificas pero no los patrones generales, el concepto.

Una forma de evaluar la capacidad de generalización de un modelo es mediante la división de los datos en dos conjuntos: entrenamiento y prueba. El modelo se ajusta a los datos de entrenamiento y se evalúa en los datos de prueba. El sobreajuste se puede detectar cuando el error en los datos de prueba es mucho mayor que el error en los datos de entrenamiento.

Los modelos más complejos tienden a sobreajustar más que lo modelos más simples. Además, ante un mismo modelo, a menor cantidad de datos es más posible que ese modelo se sobreajuste. Existen varios métodos para evaluar cuándo un modelo está sobreajustando. En la simulacion que se muestra a continuación, se ajusta un modelo de regresión lineal con diferentes cantidades de predictores (p). Se calcula el error cuadrático medio en los datos de entrenamiento y en los datos de prueba.

Código
# Generar datos sintéticos
set.seed(6)
n <- 200  # Número de observaciones
p <- 5   # Número de predictores

# Crear variables predictoras aleatorias
datos <- as.data.frame(matrix(rnorm(n * p), n, p))
colnames(datos) <- paste0("x", 1:p)

# Crear variable de respuesta con una combinación de algunas variables
datos$y <-
  3 * datos$x1 - 2 * datos$x2 + 1 * datos$x3 + rnorm(n, 0, 2)

# Dividir en conjunto de entrenamiento y prueba
entren_indice <- createDataPartition(datos$y, p = 0.7, list = FALSE)
datos_entren <- datos[entren_indice,]
datos_prueba <- datos[-entren_indice,]


resultados_lista <- lapply(1:(ncol(datos) - 1), function(z) {
  # Ajustar modelo de regresión lineal
  modelo_entren <-
    lm(y ~ ., data = datos_entren[, c("y", paste0("x", 1:z))])
  
  # Ajustar modelo de regresión lineal
  modelo_prueba <-
    lm(y ~ ., data = datos_prueba[, c("y", paste0("x", 1:z))])
  
  
  # Calcular raiz del error cuadrático medio
  rmse_entren <- rmse_lm(modelo_entren, datos_entren, response = "y")
  rmse_prueba <- rmse_lm(modelo_prueba, datos_prueba, response = "y")
  r2_entren <- r2_lm(modelo_entren)
  r2_prueba <- r2_lm(modelo_prueba)
  
  resultados <-
    data.frame(
      n_predictores = z,
      Tipo = c("Prueba", "Entrenamiendo"),
      rmse = c(rmse_prueba, rmse_entren),
      r2 = c(r2_prueba, r2_entren)
    )
  
  return(resultados)
  
})

resultados_df <- do.call(rbind, resultados_lista)

# grafico de resultados
ggplot(resultados_df, aes(x = n_predictores, y = rmse, color = Tipo)) +
  geom_line(lwd = 2) +
  scale_x_continuous(breaks = seq(0, p, 1)) +
  scale_color_viridis_d(end = 0.9) +
  labs(x = "Número de predictores", y = "RMSE") + 
  theme(legend.background = element_rect(fill = "#fff3cd"),
                          # legend in the midlle of the graph
                          legend.position = c(0.5, 0.7), 
    panel.background = element_rect(fill = "#fff3cd"),
    plot.background = element_rect(fill = "#fff3cd", colour = NA))

Podemos ver como la raíz del error cuadrático medio en los datos de prueba aumenta a medida que se aumenta el número de predictores. En cambio, el error en los datos de entrenamiento disminuye a medida que se aumenta el número de predictores. Esto es un claro indicio de sobreajuste.

El mismo patrón se observa en el coeficiente de determinación (R2). A medida que se aumenta el número de predictores, el R2 en los datos de entrenamiento aumenta, mientras que en los datos de prueba disminuye.

Código
ggplot(resultados_df, aes(x = n_predictores, y = r2, color = Tipo)) +
  geom_line(lwd = 2) +
  scale_x_continuous(breaks = seq(0, p, 1)) +
  scale_color_viridis_d(end = 0.9) +
  labs(x = "Número de predictores",
       y = bquote('Coeficiente de determinación' ~ R ^ 2)) +
  theme(
    legend.background = element_rect(fill = "#fff3cd"),
    legend.position = c(0.5, 0.7),
    panel.background = element_rect(fill = "#fff3cd"),
    plot.background = element_rect(fill = "#fff3cd", colour = NA)
  )

El aparente aumento en el poder predictivo del modelo en los datos de entrenamiento no se traduce en una mejora en los datos de prueba. Esto ocurre incluso en ausencia de variables informacion en los predictores que ayude a explicar la variable respuesta. Esto se vuelve aun mas evidente cuando se simulamos datos donde ninguno de los predictores esta asociado a la variable respuesta:

Código
# Generar datos sintéticos
set.seed(6)
n <- 200  # Número de observaciones
p <- 20   # Número de predictores

# Crear variables predictoras aleatorias
datos_aleat <- as.data.frame(matrix(rnorm(n * p), n, p))
colnames(datos_aleat) <- paste0("x", 1:p)

# Crear variable de respuesta
datos_aleat$y <- rnorm(n, 0, 2)

# Dividir en conjunto de entrenamiento y prueba
entren_indice <- createDataPartition(datos_aleat$y, p = 0.7, list = FALSE)
datos_aleat_entren <- datos_aleat[entren_indice,]
datos_aleat_prueba <- datos_aleat[-entren_indice,]


resultados_lista <- lapply(1:(ncol(datos_aleat) - 1), function(z) {
  # Ajustar modelo de regresión lineal
  modelo_entren <-
    lm(y ~ ., data = datos_aleat_entren[, c("y", paste0("x", 1:z))])
  
  # Ajustar modelo de regresión lineal
  modelo_prueba <-
    lm(y ~ ., data = datos_aleat_prueba[, c("y", paste0("x", 1:z))])
  
  
  # Calcular raiz del error cuadrático medio
  rmse_entren <- rmse_lm(modelo_entren, datos_aleat_entren, response = "y")
  rmse_prueba <- rmse_lm(modelo_prueba, datos_aleat_prueba, response = "y")
  r2_entren <- r2_lm(modelo_entren)
  r2_prueba <- r2_lm(modelo_prueba)
  
  resultados <-
    data.frame(
      n_predictores = z,
      Tipo = c("Prueba", "Entrenamiendo"),
      rmse = c(rmse_prueba, rmse_entren),
      r2 = c(r2_prueba, r2_entren)
    )
  
  return(resultados)
  
})

resultados_df <- do.call(rbind, resultados_lista)

ggplot(resultados_df, aes(x = n_predictores, y = r2, color = Tipo)) +
  geom_line(lwd = 2) +
  scale_x_continuous(breaks = seq(0, p, 1)) +
  scale_color_viridis_d(end = 0.9) +
  labs(x = "Número de predictores",
       y = bquote('Coeficiente de determinación' ~ R ^ 2)) +
  theme(
    legend.background = element_rect(fill = "#fff3cd"),
    legend.position = c(0.5, 0.7),
    panel.background = element_rect(fill = "#fff3cd"),
    plot.background = element_rect(fill = "#fff3cd", colour = NA)
  )

2 AIC

Datos con 3 predictores asociados:

Código
resultados_lista <- lapply(1:(ncol(datos) - 1), function(z) {
  # Ajustar modelo de regresión lineal
  modelo <-
    lm(y ~ ., data = datos[, c("y", paste0("x", 1:z))])

  # modelos nulos
  modelo_nulo <- lm(y ~ 1, data = datos[, c("y", paste0("x", 1:z))])

  # Calcular AIC
  aics <- AIC(modelo, modelo_nulo)
 aics$delta.aic <- aics$AIC - min(aics$AIC)

  resultados <-
    data.frame(
      n_predictores = z,
      delta_AIC = c(aics["modelo_nulo", "delta.aic"])
                    )
  
  return(resultados)
  
})

resultados_df <- do.call(rbind, resultados_lista)

ggplot(resultados_df, aes(x = n_predictores, y = delta_AIC)) +
  geom_line(lwd = 2, color = viridis(1)) +
  scale_x_continuous(breaks = seq(0, p, 1)) +
  labs(x = "Número de predictores",
       y = "Delta AIC") + 
  theme(legend.position = c(0.5, 0.7)) +
      scale_y_reverse()

Datos donde ningun predictor está asociado:

Código
resultados_lista <- lapply(1:(ncol(datos_aleat) - 1), function(z) {
  # Ajustar modelo de regresión lineal
  modelo <-
    lm(y ~ ., data = datos_aleat[, c("y", paste0("x", 1:z))])

  # modelos nulos
  modelo_nulo <- lm(y ~ 1, data = datos_aleat[, c("y", paste0("x", 1:z))])

  # Calcular AIC
  aics <- AIC(modelo, modelo_nulo)
 aics$delta.aic <- aics$AIC - min(aics$AIC)

  resultados <-
    data.frame(
      n_predictores = z,
      delta_AIC = c(aics["modelo_nulo", "delta.aic"])
                    )
  
  return(resultados)
  
})

resultados_df <- do.call(rbind, resultados_lista)

ggplot(resultados_df, aes(x = n_predictores, y = delta_AIC)) +
  geom_line(lwd = 2, color = viridis(1)) +
  scale_x_continuous(breaks = seq(0, p, 1)) +
  labs(x = "Número de predictores",
       y = "Delta AIC") + 
  theme(legend.position = c(0.5, 0.7)) +
      scale_y_reverse()

2.1 Ejercicio 1

1.1


Información de la sesión

R version 4.4.1 (2024-06-14)
Platform: x86_64-pc-linux-gnu
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=es_CR.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=es_CR.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=es_CR.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=es_CR.UTF-8 LC_IDENTIFICATION=C       

time zone: America/Costa_Rica
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] caret_6.0-94      lattice_0.22-6    nnet_7.3-19       viridis_0.6.5    
[5] viridisLite_0.4.2 ggplot2_3.5.1     knitr_1.48       

loaded via a namespace (and not attached):
 [1] gtable_0.3.5         xfun_0.48            htmlwidgets_1.6.4   
 [4] recipes_1.0.10       remotes_2.5.0        vctrs_0.6.5         
 [7] tools_4.4.1          generics_0.1.3       stats4_4.4.1        
[10] parallel_4.4.1       tibble_3.2.1         fansi_1.0.6         
[13] ModelMetrics_1.2.2.2 pkgconfig_2.0.3      Matrix_1.7-0        
[16] data.table_1.15.4    lifecycle_1.0.4      farver_2.1.2        
[19] compiler_4.4.1       stringr_1.5.1        munsell_0.5.1       
[22] codetools_0.2-20     sketchy_1.0.3        htmltools_0.5.8.1   
[25] class_7.3-22         yaml_2.3.10          prodlim_2024.06.25  
[28] pillar_1.9.0         crayon_1.5.3         MASS_7.3-61         
[31] gower_1.0.1          iterators_1.0.14     rpart_4.1.23        
[34] foreach_1.5.2        nlme_3.1-165         parallelly_1.38.0   
[37] lava_1.8.0           tidyselect_1.2.1     packrat_0.9.2       
[40] digest_0.6.37        stringi_1.8.4        future_1.34.0       
[43] reshape2_1.4.4       purrr_1.0.2          dplyr_1.1.4         
[46] listenv_0.9.1        labeling_0.4.3       splines_4.4.1       
[49] fastmap_1.2.0        grid_4.4.1           colorspace_2.1-1    
[52] cli_3.6.3            magrittr_2.0.3       survival_3.7-0      
[55] utf8_1.2.4           future.apply_1.11.2  withr_3.0.1         
[58] scales_1.3.0         timechange_0.3.0     lubridate_1.9.3     
[61] rmarkdown_2.28       globals_0.16.3       timeDate_4032.109   
[64] gridExtra_2.3        evaluate_1.0.1       hardhat_1.4.0       
[67] rlang_1.1.4          Rcpp_1.0.13          glue_1.8.0          
[70] xaringanExtra_0.8.0  pROC_1.18.5          ipred_0.9-14        
[73] rstudioapi_0.16.0    jsonlite_1.8.9       R6_2.5.1            
[76] plyr_1.8.9